{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPyTorch Classification Tutorial\n",
    "## Introduction\n",
    "\n",
    "This example is the simplest form of using an RBF kernel in an `AbstractVariationalGP` module for classification. This basic model is usable when there is not much training data and no advanced techniques are required.\n",
    "\n",
    "In this example, we’re modeling a unit wave with period 1/2 centered with positive values @ x=0. We are going to classify the points as either +1 or -1.\n",
    "\n",
    "Variational inference uses the assumption that the posterior distribution factors multiplicatively over the input variables. This makes approximating the distribution via the KL divergence possible to obtain a fast approximation to the posterior. For a good explanation of variational techniques, sections 4-6 of the following may be useful: https://www.cs.princeton.edu/courses/archive/fall11/cos597C/lectures/variational-inference-i.pdf"
   ]
  },
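  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The KL divergence mentioned above can be made concrete with a small sketch. For two univariate Gaussians it has a closed form, and it is zero exactly when the two distributions coincide. The function name `kl_gaussians` is our own and is not part of GPyTorch; the model in this tutorial uses a multivariate variational distribution, so this is only meant to convey the idea.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def kl_gaussians(mean_q, std_q, mean_p, std_p):\n",
    "    # KL(q || p) for univariate Gaussians q = N(mean_q, std_q^2), p = N(mean_p, std_p^2)\n",
    "    var_q, var_p = std_q ** 2, std_p ** 2\n",
    "    return math.log(std_p / std_q) + (var_q + (mean_q - mean_p) ** 2) / (2 * var_p) - 0.5\n",
    "\n",
    "\n",
    "# Identical distributions have zero divergence\n",
    "kl_same = kl_gaussians(0.0, 1.0, 0.0, 1.0)\n",
    "# Shifting the mean by one standard deviation costs 0.5 nats\n",
    "kl_shifted = kl_gaussians(0.0, 1.0, 1.0, 1.0)\n",
    "print(kl_same, kl_shifted)\n",
    "```\n",
    "\n",
    "Variational inference searches over the parameters of the approximating distribution to make a divergence of this kind small."
   ]
  },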
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import gpytorch\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up training data\n",
    "\n",
    "In the next cell, we set up the training data for this example. We'll be using 15 regularly spaced points on [0,1] which we evaluate the function on and add Gaussian noise to get the training labels. Labels are unit wave with period 1/2 centered with positive values @ x=0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_x = torch.linspace(0, 1, 10)\n",
    "train_y = torch.sign(torch.cos(train_x * (4 * math.pi))).add(1).div(2)"
   ]
  },
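  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see which labels the expression above produces, the same rule can be spelled out in plain Python. This is only an illustrative sketch; the names `xs` and `labels` are ours and are not used elsewhere in the notebook.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "# Same 10 regularly spaced inputs as torch.linspace(0, 1, 10)\n",
    "xs = [i / 9 for i in range(10)]\n",
    "# The sign of cos(4 * pi * x) is +1 or -1; (sign + 1) / 2 maps it to 1 or 0\n",
    "labels = [1 if math.cos(4 * math.pi * x) > 0 else 0 for x in xs]\n",
    "print(labels)\n",
    "```\n",
    "\n",
    "This prints `[1, 1, 0, 0, 1, 1, 0, 0, 1, 1]`: two full periods of the wave over [0,1]."
   ]
  },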
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the classification model\n",
    "\n",
    "The next cell demonstrates the simplist way to define a classification Gaussian process model in GPyTorch. If you have already done the [GP regression tutorial](../01_Simple_GP_Regression/Simple_GP_Regression.ipynb), you have already seen how GPyTorch model construction differs from other GP packages. In particular, the GP model expects a user to write out a `forward` method in a way analogous to PyTorch models. This gives the user the most possible flexibility.\n",
    "\n",
    "Since exact inference is intractable for GP classification, GPyTorch approximates the classification posterior using **variational inference.** We believe that variational inference is ideal for a number of reasons. Firstly, variational inference commonly relies on gradient descent techniques, which take full advantage of PyTorch's autograd. This reduces the amount of code needed to develop complex variational models. Additionally, variational inference can be performed with stochastic gradient decent, which can be extremely scalable for large datasets.\n",
    "\n",
    "If you are unfamiliar with variational inference, we recommend the following resources:\n",
    "- [Variational Inference: A Review for Statisticians](https://arxiv.org/abs/1601.00670) by David M. Blei, Alp Kucukelbir, Jon D. McAuliffe.\n",
    "- [Scalable Variational Gaussian Process Classification](https://arxiv.org/abs/1411.2005) by James Hensman, Alex Matthews, Zoubin Ghahramani.\n",
    "\n",
    "### The necessary classes\n",
    "\n",
    "For most variational GP models, you will need to construct the following GPyTorch objects:\n",
    "\n",
    "1. A **GP Model** (`gpytorch.models.AbstractVariationalGP`) -  This handles basic variational inference.\n",
    "1. A **Variational distribution** (`gpytorch.variational.VariationalDistribution`) - This tells us what form the variational distribution q(u) should take.\n",
    "1. A **Variational strategy** (`gpytorch.variational.VariationalStrategy`) - This tells us how to transform a distribution q(u) over the inducing point values to a distribution q(f) over the latent function values for some input x.\n",
    "1. A **Likelihood** (`gpytorch.likelihoods.BernoulliLikelihood`) - This is a good likelihood for binary classification\n",
    "1. A **Mean** - This defines the prior mean of the GP.\n",
    "  - If you don't know which mean to use, a `gpytorch.means.ConstantMean()` is a good place to start.\n",
    "1. A **Kernel** - This defines the prior covariance of the GP.\n",
    "  - If you don't know which kernel to use, a `gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())` is a good place to start.\n",
    "1. A **MultivariateNormal** Distribution (`gpytorch.distributions.MultivariateNormal`) - This is the object used to represent multivariate normal distributions.\n",
    "  \n",
    "  \n",
    "#### The GP Model\n",
    "  \n",
    "The `AbstractVariationalGP` model is GPyTorch's simplist approximate inference model. It approximates the true posterior with a distribution specified by a `VariationalDistribution`, which is most commonly some form of MultivariateNormal distribution. The model defines all the variational parameters that are needed, and keeps all of this information under the hood.\n",
    "\n",
    "The components of a user built `AbstractVariationalGP` model in GPyTorch are:\n",
    "\n",
    "1. An `__init__` method that constructs a mean module, a kernel module, a variational distribution object and a variational strategy object. This method should also be responsible for construting whatever other modules might be necessary.\n",
    "\n",
    "2. A `forward` method that takes in some $n \\times d$ data `x` and returns a MultivariateNormal with the *prior* mean and covariance evaluated at `x`. In other words, we return the vector $\\mu(x)$ and the $n \\times n$ matrix $K_{xx}$ representing the prior mean and covariance matrix of the GP.\n",
    "\n",
    "(For those who are unfamiliar with GP classification: even though we are performing classification, the GP model still returns a `MultivariateNormal`. The likelihood transforms this latent Gaussian variable into a Bernoulli variable)\n",
    "\n",
    "Here we present a simple classification model, but it is posslbe to construct more complex models. See some of the [scalable classification examples](../07_Scalable_GP_Classification_Multidimensional/KISSGP_Kronecker_Classification.ipynb) or [deep kernel learning examples](../08_Deep_Kernel_Learning/Deep_Kernel_Learning_DenseNet_CIFAR_Tutorial.ipynb) for some other examples."
   ]
  },
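  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate how a likelihood can turn a latent Gaussian into a class probability, here is a minimal sketch assuming a probit link, i.e. $p(y=1 \\mid f) = \\Phi(f)$ with $\\Phi$ the standard normal CDF. Under that assumption, averaging over a latent $f \\sim N(\\mu, \\sigma^2)$ has the closed form $\\Phi(\\mu / \\sqrt{1 + \\sigma^2})$. The helper names below are ours, and the sketch is not meant as a description of GPyTorch's internals.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def normal_cdf(z):\n",
    "    # Standard normal CDF written with the error function\n",
    "    return 0.5 * (1.0 + math.erf(z / math.sqrt(2.0)))\n",
    "\n",
    "\n",
    "def probit_marginal(mean, var):\n",
    "    # P(y = 1) when f ~ N(mean, var) and p(y = 1 | f) = normal_cdf(f)\n",
    "    return normal_cdf(mean / math.sqrt(1.0 + var))\n",
    "\n",
    "\n",
    "# A latent mean of zero gives probability 0.5, whatever the variance\n",
    "p_uncertain = probit_marginal(0.0, 4.0)\n",
    "# More latent variance pulls a confident prediction back toward 0.5\n",
    "p_low_var = probit_marginal(2.0, 0.1)\n",
    "p_high_var = probit_marginal(2.0, 10.0)\n",
    "print(p_uncertain, p_low_var, p_high_var)\n",
    "```\n",
    "\n",
    "The latent variance matters as well as the latent mean: the same mean yields a less extreme class probability when the GP is less certain about the latent function."
   ]
  },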
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gpytorch.models import AbstractVariationalGP\n",
    "from gpytorch.variational import CholeskyVariationalDistribution\n",
    "from gpytorch.variational import VariationalStrategy\n",
    "\n",
    "\n",
    "class GPClassificationModel(AbstractVariationalGP):\n",
    "    def __init__(self, train_x):\n",
    "        variational_distribution = CholeskyVariationalDistribution(train_x.size(0))\n",
    "        variational_strategy = VariationalStrategy(self, train_x, variational_distribution)\n",
    "        super(GPClassificationModel, self).__init__(variational_strategy)\n",
    "        self.mean_module = gpytorch.means.ConstantMean()\n",
    "        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())\n",
    "\n",
    "    def forward(self, x):\n",
    "        mean_x = self.mean_module(x)\n",
    "        covar_x = self.covar_module(x)\n",
    "        latent_pred = gpytorch.distributions.MultivariateNormal(mean_x, covar_x)\n",
    "        return latent_pred\n",
    "\n",
    "\n",
    "# Initialize model and likelihood\n",
    "model = GPClassificationModel(train_x)\n",
    "likelihood = gpytorch.likelihoods.BernoulliLikelihood()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Model modes\n",
    "\n",
    "Like most PyTorch modules, the `ExactGP` has a `.train()` and `.eval()` mode.\n",
    "- `.train()` mode is for optimizing variational parameters model hyperameters.\n",
    "- `.eval()` mode is for computing predictions through the model posterior."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Learn the variational parameters (and other hyperparameters)\n",
    "\n",
    "In the next cell, we optimize the variational parameters of our Gaussian process.\n",
    "In addition, this optimization loop also performs Type-II MLE to train the hyperparameters of the Gaussian process.\n",
    "\n",
    "The most obvious difference here compared to many other GP implementations is that, as in standard PyTorch, the core training loop is written by the user. In GPyTorch, we make use of the standard PyTorch optimizers as from `torch.optim`, and all trainable parameters of the model should be of type `torch.nn.Parameter`. The variational parameters are predefined as part of the `VariationalGP` model.\n",
    "\n",
    "In most cases, the boilerplate code below will work well. It has the same basic components as the standard PyTorch training loop:\n",
    "\n",
    "1. Zero all parameter gradients\n",
    "2. Call the model and compute the loss\n",
    "3. Call backward on the loss to fill in gradients\n",
    "4. Take a step on the optimizer\n",
    "\n",
    "However, defining custom training loops allows for greater flexibility. For example, it is possible to learn the variational parameters and kernel hyperparameters with different learning rates."
   ]
  },
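  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a sketch of the different-learning-rates idea, `torch.optim` optimizers accept a list of parameter groups, each with its own options. The `ToyModel` class and its parameter names below are hypothetical stand-ins, not GPyTorch objects; with a real model you would select the relevant parameters from the model yourself.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "\n",
    "class ToyModel(torch.nn.Module):\n",
    "    # Hypothetical stand-in with two kinds of parameters\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.variational_mean = torch.nn.Parameter(torch.zeros(3))\n",
    "        self.lengthscale = torch.nn.Parameter(torch.ones(1))\n",
    "\n",
    "\n",
    "toy = ToyModel()\n",
    "# Each dict is one parameter group with its own learning rate\n",
    "toy_optimizer = torch.optim.Adam([\n",
    "    {'params': [toy.variational_mean], 'lr': 0.1},\n",
    "    {'params': [toy.lengthscale], 'lr': 0.01},\n",
    "])\n",
    "\n",
    "# One pass of the usual loop: zero gradients, compute loss, backward, step\n",
    "toy_optimizer.zero_grad()\n",
    "toy_loss = (toy.variational_mean - 1.0).pow(2).sum() + (toy.lengthscale - 2.0).pow(2).sum()\n",
    "toy_loss.backward()\n",
    "toy_optimizer.step()\n",
    "print([group['lr'] for group in toy_optimizer.param_groups])\n",
    "```\n",
    "\n",
    "After this single step, the parameter in the first group has moved roughly ten times further than the one in the second group, reflecting the two learning rates."
   ]
  },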
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Iter 1/50 - Loss: 0.905\n",
      "Iter 2/50 - Loss: 5.393\n",
      "Iter 3/50 - Loss: 8.450\n",
      "Iter 4/50 - Loss: 3.792\n",
      "Iter 5/50 - Loss: 6.634\n",
      "Iter 6/50 - Loss: 7.030\n",
      "Iter 7/50 - Loss: 6.352\n",
      "Iter 8/50 - Loss: 4.972\n",
      "Iter 9/50 - Loss: 3.931\n",
      "Iter 10/50 - Loss: 3.196\n",
      "Iter 11/50 - Loss: 2.832\n",
      "Iter 12/50 - Loss: 2.654\n",
      "Iter 13/50 - Loss: 2.448\n",
      "Iter 14/50 - Loss: 2.242\n",
      "Iter 15/50 - Loss: 1.954\n",
      "Iter 16/50 - Loss: 1.772\n",
      "Iter 17/50 - Loss: 1.581\n",
      "Iter 18/50 - Loss: 1.569\n",
      "Iter 19/50 - Loss: 1.478\n",
      "Iter 20/50 - Loss: 1.451\n",
      "Iter 21/50 - Loss: 1.443\n",
      "Iter 22/50 - Loss: 1.443\n",
      "Iter 23/50 - Loss: 1.443\n",
      "Iter 24/50 - Loss: 1.435\n",
      "Iter 25/50 - Loss: 1.420\n",
      "Iter 26/50 - Loss: 1.398\n",
      "Iter 27/50 - Loss: 1.372\n",
      "Iter 28/50 - Loss: 1.344\n",
      "Iter 29/50 - Loss: 1.313\n",
      "Iter 30/50 - Loss: 1.280\n",
      "Iter 31/50 - Loss: 1.246\n",
      "Iter 32/50 - Loss: 1.213\n",
      "Iter 33/50 - Loss: 1.182\n",
      "Iter 34/50 - Loss: 1.153\n",
      "Iter 35/50 - Loss: 1.126\n",
      "Iter 36/50 - Loss: 1.103\n",
      "Iter 37/50 - Loss: 1.085\n",
      "Iter 38/50 - Loss: 1.069\n",
      "Iter 39/50 - Loss: 1.055\n",
      "Iter 40/50 - Loss: 1.041\n",
      "Iter 41/50 - Loss: 1.025\n",
      "Iter 42/50 - Loss: 1.010\n",
      "Iter 43/50 - Loss: 0.993\n",
      "Iter 44/50 - Loss: 0.976\n",
      "Iter 45/50 - Loss: 0.958\n",
      "Iter 46/50 - Loss: 0.943\n",
      "Iter 47/50 - Loss: 0.928\n",
      "Iter 48/50 - Loss: 0.915\n",
      "Iter 49/50 - Loss: 0.903\n",
      "Iter 50/50 - Loss: 0.892\n"
     ]
    }
   ],
   "source": [
    "from gpytorch.mlls.variational_elbo import VariationalELBO\n",
    "\n",
    "# Find optimal model hyperparameters\n",
    "model.train()\n",
    "likelihood.train()\n",
    "\n",
    "# Use the adam optimizer\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=0.1)\n",
    "\n",
    "# \"Loss\" for GPs - the marginal log likelihood\n",
    "# num_data refers to the amount of training data\n",
    "mll = VariationalELBO(likelihood, model, train_y.numel())\n",
    "\n",
    "training_iter = 50\n",
    "for i in range(training_iter):\n",
    "    # Zero backpropped gradients from previous iteration\n",
    "    optimizer.zero_grad()\n",
    "    # Get predictive output\n",
    "    output = model(train_x)\n",
    "    # Calc loss and backprop gradients\n",
    "    loss = -mll(output, train_y)\n",
    "    loss.backward()\n",
    "    print('Iter %d/%d - Loss: %.3f' % (i + 1, training_iter, loss.item()))\n",
    "    optimizer.step()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Make predictions with the model\n",
    "\n",
    "In the next cell, we make predictions with the model. To do this, we simply put the model and likelihood in eval mode, and call both modules on the test data.\n",
    "\n",
    "In `.eval()` mode, when we call `model()` - we get GP's latent posterior predictions. These will be MultivariateNormal distributions. But since we are performing binary classification, we want to transform these outputs to classification probabilities using our likelihood.\n",
    "\n",
    "When we call `likelihood(model())`, we get a `torch.distributions.Bernoulli` distribution, which represents our posterior probability that the data points belong to the positive class.\n",
    "\n",
    "```python\n",
    "f_preds = model(test_x)\n",
    "y_preds = likelihood(model(test_x))\n",
    "\n",
    "f_mean = f_preds.mean\n",
    "f_samples = f_preds.sample(sample_shape=torch.Size((1000,))\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQgAAADDCAYAAAB+ro88AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvOIA7rQAAFUFJREFUeJzt3V9sG1W+B/Cv86dxulC7CRuCgoTsInGRkIiTs1rYJ5o4EurL3YaUsuguQhs10oJ2XyqgqJWiSjQpcLPassAVuTFa6aKsoo3St4pdSHa1aKtb7rQOT+WhtcVTQqCpm7KQ/74PnnGcZI499tjjOZPvR6oce+zj02P7N+ecOTM/XzqdBhGRmZpqV4CI3IsBgoikGCCISIoBgoikGCCISKrObgFCiAH9z0Oapr1msr0PQApAh6Zpb9l9PyJyjq0ehBAiCuBTTdNGAYT1+7nbOwBA07RPAaSM+0SkBrtDjDAAIygk9Pu5jiPTezC2R0FEyrA1xNB7DoYOABM7nhIEsJhzv9nO+xGRs2zPQQDZocQ1TdOulVrGqVOnuKSTqErOnz/vM3u8LAECQNRsghKZ4UWT/ncQwK18hZw9e7bgGy0sLKClpaXoCjrF7fUDWMdycHv9AOt1HBwclG6zfZhTCDFgHJ0wJimFEEF98wS25iXCAD61+35E5JxyHMV4UwhxUwhxO2fTNAAYQw79eSk7QxAicp7dScpPARw0ebwz5+/RnduJ8llfX8fdu3dx9+5duPVs483NTSwtLVW7GnntrKPP50NDQwNaW1tRV2ftp1+uOQiispmfn0cgEEBzczN8PtO5s6pbW1tDfX19tauR1846ptNppFIpzM/P48EHH7RUBpdak+usrKzgwIEDVQ8O8Xgc8Xi84u+TSqUwNTVV8ffx+XwIBoNYWVmx/BoGCHKddDptOTjMzc0hGo1ifn6+5PeLx+MYGxvD9PQ0xsbGkEgkAACBQACTk5Mll2tVMBg0fZ94PI5HH30UU1NTmJqawsjISLZuZvJtM/h8vqKGbRxikNKGh4dx+fJlDA0N4Z133in69alUCm+//TbGx8ezjz3//PMYHx9HU1NTnleW18GDu6byEIlEEAqF0Nvbm33syJEjuHTp0q7nJhIJxGIxnDt3rqz1YoAgJQWDQSwvL2fvj46OYnR0FH6/H6lUKs8rt5ucnERXV9e2xw4ePIjp6Wl0dnYiHo9jenoas7Oz6O/vx9WrVwEAV69eRV9fH2ZmZtDU1IRQKIRkMonJyUmEQiE88sgj+PjjjzE+Po6XX34ZJ0+eBIBtzw+FQojFYmhvb8e1a9YO8AWDwWxPYWZmBgDQ1dWF2dlZJJNJxONxBAIBzMzMYGNjAz09PQiHd54BYR2HGKSk69ev4/jx42hsbAQANDY24rnnnsOXX35ZdFl37tyRbotEIuju7kZ7eztisRhmZ2cxMzODw4cP48yZM+js7MwGh66uLhw8eBDnzp3DCy+8kC2jt7cX4XB41/NPnz6No0ePoru7G6FQqKg6h8NhNDU1oampCRcvXkRXVxdCoRAikciubXYwQJCSHnjgARw4cAArKyvw+/3Zic3W1taiyunq6sr2CgzJZBLd3d3bHjOGG0ePHkV/fz9GRkawurqKQCCASCSS7YUEg8Hsa7q6ujAyMoLOzuxR/13PL1YqlUI4HMbIyAgCgQDa29uzjwOZoYax7fHHH9+2rRQcYpCyFhYWcOLECfT39yMWi5U0URkOh/HKK69gbGwMoVAIs7OzePfdd7PbU6nUtiGGMSQ4fPgwenp6EIvFsntvo4ufSqUQDAbR19eH06dPZ4PGG2+8se35J0+exMWLF9He3p59bSQSyb53PB5HMpnMHuFIJpPZuhnvd+fOHSQSCdy+fRupVArJZDK7bXFxEYlEAslkclu5xfC5ZSHKqVOn0jwXwxlur+ONGzfw0EMPuXqdgYrrIAw3btzA
ww8/nL0/ODgoPVmLQwwikmKAICIpBggikmKAICIpBggikmKAICIpBgja0+LxOJ588sltZ20mEoldj+1VXChFruX3N5StrOVl81OcI5FIdqHUe++9ByCz9NpYtrzXlaUHkS8hjhDiTf12QPYcomoKBALSbYlEAmNjY5iamkI8Hs/e//DDD5FIJDA9PY0jR45genoap0+fdrDWzijHRWujAP6c5ykDQoibyCTOIbJseXmlbP8K6e3txdjY2K7lzjtPsNp5IlR3dzeCwSC6u7ttnfPgVrYDhH5dynw//hOaph3Sn0fkSt3d3dnTp3fKPcHK7EQos2s5eIUTk5RhIURUCPGqA+9FVJR4PI5YLIZEIpHtKRiXmovH49kTrKanp7G4uJjtSXzxxRdIJBK4dOkSkslk9qQor01sVnySMidnRo8QIsqeBLlJJBLJXk3KuKhLJBLB9evXs8/JvUqTcfGVtbU1HDt2DEDmClQATK/0pLqKBgh9YnJR07RJZLJq5b20zcLCQsEy3T7Oc3v9APfXcXNzExsbG9WuRl5urx8gr+Pm5qal3xpQoQAhhAhqmpYCoGFrfuIQgA/yvc7qKchuPlUZcH/9AHfXcWlpCbW1ta4/ndrt9QPM61hTU2P58y/HUYy+zI3oy3k4N7PWs/q2m8ysRVYUe+Vlsq6YK4YDZehB6MOHyR2PMbMWlayhoQFLS0uuTpyjIiNxTkOD9QVoXElJrtPa2oqvvvoKd+7ccW1PYnNzEzU17j5TYWcdc1PvWcUAQa5TV1eHe++919XzJG6/bB9Qnjq6OwQSUVUxQBCRFAMEEUkxQBCRFAMEEUkxQBCRFAMEEUkxQBCRFAMEEUkxQBCRFAMEEUkxQBCRFAMEEUkxQBCRFAMEEUkxQBCRFAMEEUk5kZuzj4lziNRU0dycRuDQk+Wk8gUSK+bm5nDs2DHMz8/bKca03Gg0WvZyVVXJ9mBbb6l0O5fjt1Lp3JzHARhZWhIAonbea3h4GJ9//jmGhobsFGNa7uXLl8terqoq2R5s6y2Vbudy/FZ85bhqsBDiE03Tekwe/wDAB5qmXdN7Gj2apr1mVsapU6fSZ8+eNS0/GAxieTkE4B+7tjU3N5dc71u3bkm37Sy3rS2Nv/xlDYXytLrpYqbpNPDMM3W4cmX7fkB2ReZi2qNYxZadW8eGBuB3v1vHz3++aasO5WTlc/7jH2swOFiH9fXtjzvXzv8A8AwAwO/3SzOqDQ4O4vz586b5BVx1VWtZOrDPPvsMr776P/jb3+7btS1PW1uwuzxZubdu+fDJJ0t46qnlvCW6Ka3dN9/U4NKlNpMttZJXWG+P4hVb9vY6/ulPq/jZzxbtVqJsrHzO4+P34euvzbJvOdXOB+D3+/H000/jzJkzltPt5ap0gEgBaNL/DiKTn1NKFpFbWloQDq/g739vQX19PdbW1vDLX/4HhoaGbVfw9ddfx0cffYR9+/ZhdXXVtNxf/aoef/1rDRoagmhpKbwXc0sP4l//yty2taVx5cpq9vFvv/0W991n/iW10h6lKqZso46XLtVgYKAe6bTfNe1qKFSf9fVMcJicXMMTT2z/3jjRzvX1wOrqKlpaWvDYY4+VVFalc3NOABD6w2EAJWf2/vbbeQwM9OLo0aO4ePEi5ucTkHzHi7K0lMDAQC/6+/sRi8VMyw0EMsOwH36w/35OWl7O9BrvuSe97f+0ubkpbTsr7VGqYso26nj//Zn7P/ygXoatZb2z2dqa3vX/dKKdt34rpU9U2g4Qubk59TR8QCY3Z6c+9yD0+YeUndycExMTADLDkKeeesputXeVCwAXLlwwfU5jY+b2++/L9raOMOpr1N8KK+1RqlLK9vszwVm1tgfyt78T7VyO3wpzc1qwf3/mdjn/9IPrGD2eYgKE26ja9sBWr6ex0Z3pA63gSkoL9u839mJqdXONAGH8yFRk1F3FHoQX2p8BwgK/P3Or2hyEF/ZgxhBDxTkIL/TgGCAsMPYAqgWI
UuYg3EbVtge22p89CI8zfmCq7cW8sAfbavvq1qNYa2vAxoYPtbVp1JsthVAEA4QFRhddtXGwF8bAqs5BeKH3ADBAWKLuYU715yDq64GamjTW131YW6t2bazzwvAOYICwRNVDbV4YYvh8as5DeKHtAQYIS7aGGGrOQajezVWxB2fMVxmHyFXFAGGBqhNlW4c5q1wRm1Rsf/Yg9hAVv6BA7jhY7b2YUX+VjiJxDmIPUXEMDHhnL6Zi+xvzVaq3PQOEBZyDqC4V5yCM7wrnIPYA4wuq3lEM9Q9zAmq2vxGcjWX6qmKAsED1xTqqd3NVPFmOC6X2kNwubhku4ekYrwwxjL2wSgGahzn3kPp6oK4ujc1NtVbzcZKyerzS9gwQFqk8UeaVOQge5nQeA4RFKu7FvHKozeims+2dxwBhkYrjYK9MlKnZ9pyDAFA496YQ4k39dsDue1WT8UEbV4p2u3TaO+NgFU+W42FOWM69OSCEuAl5ej4lqHao07hgSV2d2hcsAdQ8zOmVI0h2r2p9HMAn+t9G7s2dl7Y/kXM5fGWpdl1Kr3xBAfXaHvBO+9sNEEEAufnQzBILhvW8GB2apr2VrzArqcGqldquru4+AI2Ym7uDhQV5X9ctqfe+/roGQBsaGjZ3tatb6phPbh3X1/cDaMbt28tYWHBH+r1CbZhK/RhALVZWbmNhYcWZSu2qg/3PueK5OY2gIIToEUJE9eGIKaup1aqRgi0QyDTVvn2F0++5IUXc3buZ2x/9qMa0Pm6oYyFGHe+/PzMS3tx0V/q9fHUx0u498EAQLS3Vm6i02152Jynz5t4UQgzombegbwvbfL+qUe2Ub2MyVfU1EEDuYU515iB4mDNjAls/+mzuTSFEUH9Mw1Y+zkP6fSWptlDKK4c4AVUPc2ZuVW9/WwHCyLVpkntzOmf7s3ov4qad3JzVptphTq8cZgNUPcyZ+Z4YiX9UVY7cnLtyb3otNyegXg9i62ShKlekDFRre8A7RzG4ktIi1eYgtr6gau/BADUvOccAsceoFiCMva0XhhiqtX06zZO19hzVZtK9sgcD1DtRbnUVSKd92LcvjdraatfGHgYIi1QbB3vlcnOAehfs8UrvAWCAsEy16yJ6qQeh2gV7vNT2DBAWqXaylpfmIAC12t9Lh5gZICwyjmerckahl/ZigFpDPK9cCwJggLBMtYkyr1w01aDSkQwvBWcGCItU+oIC3htiqLQWwkttzwBhkWo9CGMy1Qt7MUCt9jeW43uh98YAYZFqcxBeOtQGqHXClpfangHCItVOGPLSOgggt/3dH6C9ci1QgAHCMpUOswHe2osBudelrHJFLPBS2zNAWJQ7SanCaj6vzUGodF1KzkHsQbW1wL59aaTTPqxU5xKDRfHKBUsMKvXg2IPYo1RcrKP6BUsMKh7mZIDYY9Q61Ja59VoPgm3vLAaIImyNg7kXc5pahzm9cwTJidR7eberpJQksnNzc4hGo5ifny9rXfKV66W0e4ZCh5kr1c6llO2ltq9o6j2LqfmUUcocxPDwMC5fvoyhoaGy1iVfuSsr3rlgicHYG8sWqlWqnUsp20sBotKp96yk5lOG8YFfuFCLtjbz7uP33wexf38t3n//v7CxsQ7g3wD8J0ZHgdHR91FbW4eXXvp1yXWwUu7KincuWGsw2v6zz3x45ZWtqFepdi5U9osvvoj9+82j77Vrmf2uF9q/0qn3rKTmy3Jz6j0ACAabAezH1FS+3fK9+u1vTLdubAB/+IOdWlgvt7l53bRNVUu9BwB+vx/Aj3H9eg2uX8/t+FaqnfOXHYsVfnV9/SIWFlbtVqJkSqTeK4abU+8BwO9/Dxw+vIb1dfkk5Xff3cU992SCxNTUFK5c+V/U1tZhY2MdP/3pE+jt7bVdD6vldnenpW3lphR2Mrl17OsD0uk1fPPN7ravVDvnKzv3czbT1pZGT08QvirPZ9v9nO0GiLyp9yxsV0pbG/DSS/nzci4sfIeWlkzf8p//
HMfAQCv6+/sRi8UwPz+O3/72323Xo1LlulldHfCLX5i3fSXbQ1Z27ufsZXYDxAQAof+9LfWepmkp2fa9YmJiIvv3hQsXXF+uqirZHnu9rZ1IvWe2nYgUwNR7RCTFlZREJMUAQURSDBBEJMUAQURSDBBEJMUAQURSDBBEJMUAQURSDBBEJMUAQURSDBBEJMUAQURSDBBEJMUAQURSDBBEJMUAQURSDBBEJMUAQURSDBBEJOVEbs439dsBu+9FRM6qaG5O3YAQ4iYyqfeISCGVzs0JACc0TZu0+T5EVAWVzs0JAGE9L0aHpmlv5SvM7bk5rXB7/QDWsRzcXj9AkdycRlAQQvQIIaL6cMSU23NzWuX2+gGsYzm4vX6AA7k5JZOLCWPeAXlyb+qvXdSHGLeQSb9HRIooGCAKZMYqlJtTw9bk5CEAH5ReVSJymhO5OZ8VQvQBuMncnERqYW5OIpLiSkoikmKAICIpBggikmKAICIpBggikmKAICIpBggikmKAICIpBggikmKAICIpBggikmKAICIpBggikmKAICIpBggikmKAICIpBggikipLgJAkzDG25c28RUTuVY7Ue1EAf5Zss5J5i4hcynaA0H/8srR6x5G5ND6wlXmLiBRR6TkIK5m3iMilKp5ZqxiDg4PVrgIR5bCbWauQvJm3cp0/f95noTwicpDdzFqmcjJrmWbeIiI1lOMoRl/mRvTlPJybWcss8xYRKcCXTqerXQcicimupPQ4qwvVuJBNbZVarOiqoxi59CFLCkCHpmlvFbvdCRbqaEzwHtI07TVHK4ftC9WEEGEhRIfZME8fAvYAcLwdLbRhBzLzV9A0bdLh6hl1sPpdDFcjF63++X0A4JDJNkvfARlX9iAKrcB0wwpNC3WMAvhU/8KE9ftOc/VCNYuf4+t6YAi79HPuwNZRvUQ16ljJxYquDBAo/J9ywxe/UB3COY8l9PtOK7hQTd+jVOvoUt421PfM/wcAmqa9VaVJbivftTf127ALJ+JtLVZ0a4Ao9J9ywwrNvHXQNG00p7vZAUBzqmJFair8lIop9Dn+BECzEKKjinMkhT7na8j0HG7veJ4nuDVAeIbe5bxWpT1L3oVqVe49WHUr53B5X6EnO00IEUSmnYcB/LcQoho9xXwsL1Y049YAUeg/Zes/XSZW6xCtxgSlbgJbQ5vsQjX9Sw1kxvV9+mRqUxXGz4Xa8Ba2xtYpZHoUTitUxwEAw/rk5QkArghiOZ+x6XfAKrcGiEJfbFv/6TIpVEcIIQaMWe9qTFLmWahmLGSbzDkyEDQpotIKteFkzvYg9PkIhxX8nA16W6Z2Pl5plVys6NqFUvpeLYGcQ0dCiKuapnXKtrupjjnXyVhEZg90TIHuvOMsfs6LAH5SrZ6YhTq+qm9vqtZ3sVJcGyCIqPrcOsQgIhdggCAiKQYIIpJigCAiKQYIIpJigCAiKQYIIpL6f+n9W6Tl4ReeAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 288x216 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Go into eval mode\n",
    "model.eval()\n",
    "likelihood.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Test x are regularly spaced by 0.01 0,1 inclusive\n",
    "    test_x = torch.linspace(0, 1, 101)\n",
    "    # Get classification predictions\n",
    "    observed_pred = likelihood(model(test_x))\n",
    "\n",
    "    # Initialize fig and axes for plot\n",
    "    f, ax = plt.subplots(1, 1, figsize=(4, 3))\n",
    "    ax.plot(train_x.numpy(), train_y.numpy(), 'k*')\n",
    "    # Get the predicted labels (probabilites of belonging to the positive class)\n",
    "    # Transform these probabilities to be 0/1 labels\n",
    "    pred_labels = observed_pred.mean.ge(0.5).float()\n",
    "    ax.plot(test_x.numpy(), pred_labels.numpy(), 'b')\n",
    "    ax.set_ylim([-1, 2])\n",
    "    ax.legend(['Observed Data', 'Mean', 'Confidence'])"
   ]
  },
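  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The thresholding step used in the plot can also be used to score predictions. The following is a minimal sketch on made-up numbers; `probs` and `targets` are hypothetical values chosen for illustration, not outputs of the model above.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "# Hypothetical positive-class probabilities and true 0/1 labels\n",
    "probs = torch.tensor([0.9, 0.2, 0.55, 0.4])\n",
    "targets = torch.tensor([1.0, 0.0, 0.0, 0.0])\n",
    "\n",
    "# Probabilities of at least 0.5 become label 1, the rest label 0\n",
    "pred = probs.ge(0.5).float()\n",
    "# Fraction of points whose predicted label matches the target\n",
    "accuracy = pred.eq(targets).float().mean().item()\n",
    "print(pred.tolist(), accuracy)\n",
    "```\n",
    "\n",
    "Here three of the four thresholded predictions match their targets, so the accuracy is 0.75."
   ]
  },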
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
